import torch
import numpy as np
# import contextlib
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# import numpy as np


class XentEC(nn.Module):
    """Cross-entropy-like loss whose normalizer is computed from ``cossim``.

    ``forward(logit, cossim, target)`` returns the mean of
    ``neg(logit, cossim) - logit[target]``, where ``logit[target]`` is gathered
    along dim 1 with one index per row of ``target``.

    ``args.i_pos_fn`` picks the elementwise map applied to ``cossim``
    (0: ``torch.abs``, 1: ``torch.square``, 2: ``abs_smooth``).
    ``args.i_neg_fn`` picks the normalizer (0: ``neg1``, 1: ``neg2``).
    ``args.n_cls`` is stored as ``self.n_cls`` but is not used by this class.
    """

    def __init__(self, args):
        super().__init__()
        self.n_cls = args.n_cls
        pos_choices = (torch.abs, torch.square, self.abs_smooth)
        neg_choices = (self.neg1, self.neg2)
        self.pos_fn = pos_choices[args.i_pos_fn]
        self.neg = neg_choices[args.i_neg_fn]

    def forward(self, logit, cossim, target):
        # One gathered logit per row; presumably logit is (batch, classes) and
        # target holds class indices -- TODO confirm at call sites.
        target_logit = logit.gather(1, target.unsqueeze(1))
        normalizer = self.neg(logit, cossim)
        per_entry = -(target_logit - normalizer)
        return per_entry.mean()

    def neg1(self, logit, cossim):
        """Return logsumexp of ``pos_fn(cossim)`` over dim 1, keeping that dim."""
        return torch.logsumexp(self.pos_fn(cossim), dim=1, keepdim=True)

    def neg2(self, logit, cossim):
        """Weight ``pos_fn(cossim)`` by a gradient-free softmax of ``logit``.

        NOTE(review): unlike ``neg1`` the product is NOT reduced over dim 1.
        In ``forward`` this unreduced result broadcasts against the (B, 1)
        gathered logit, and ``mean()`` then averages over every entry rather
        than per sample. Confirm whether ``.sum(dim=1, keepdim=True)`` was
        intended.
        """
        mapped = self.pos_fn(cossim)
        with torch.no_grad():
            # Built under no_grad, so no gradient flows through the weights.
            weights = F.softmax(logit, dim=1)
        return weights * mapped

    def abs_smooth(self, logit, beta = 1.0):
        """Smooth-L1 map: ``0.5*x**2/beta`` where ``|x| < beta``, else ``|x| - 0.5*beta``."""
        inside = logit.abs() < beta
        quadratic = inside * (0.5 * torch.square(logit) / beta)
        linear = (~inside) * (logit.abs() - 0.5 * beta)
        return quadratic + linear


class NegEC(nn.Module):
    """Normalizer-only counterpart of ``XentEC`` applied directly to ``logit``.

    ``args.i_pos_fn`` picks the elementwise map (0: ``torch.abs``,
    1: ``torch.square``, 2: ``abs_smooth``). ``args.i_neg_fn`` picks the
    reduction (0: ``neg1``, 1: ``neg2``). ``forward`` returns the selected
    result unchanged; no batch reduction is applied here.
    """

    def __init__(self, args):
        super().__init__()
        pos_choices = (torch.abs, torch.square, self.abs_smooth)
        neg_choices = (self.neg1, self.neg2)
        self.pos_fn = pos_choices[args.i_pos_fn]
        self.neg = neg_choices[args.i_neg_fn]

    def forward(self, logit):
        return self.neg(logit)

    def neg1(self, logit):
        """Return logsumexp of ``pos_fn(logit)`` over dim 1, keeping that dim."""
        return torch.logsumexp(self.pos_fn(logit), dim=1, keepdim=True)

    def neg2(self, logit):
        """Return ``softmax(logit) * pos_fn(logit)`` elementwise, with the softmax gradient-free.

        The product is not reduced over dim 1; callers decide how to reduce it.
        """
        mapped = self.pos_fn(logit)
        with torch.no_grad():
            # Built under no_grad, so no gradient flows through the weights.
            weights = F.softmax(logit, dim=1)
        return weights * mapped

    def abs_smooth(self, logit, beta = 1.0):
        """Smooth-L1 map: ``0.5*x**2/beta`` where ``|x| < beta``, else ``|x| - 0.5*beta``."""
        inside = logit.abs() < beta
        quadratic = inside * (0.5 * torch.square(logit) / beta)
        linear = (~inside) * (logit.abs() - 0.5 * beta)
        return quadratic + linear




# criterion = nn.CrossEntropyLoss()
def mixup_criterion(criterion, pred, y_a, y_b, lam, cond=True):
    """Combine a criterion over the two targets of a mixup pair.

    With ``cond=True`` the two losses are blended outside the criterion:
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    With ``cond=False`` the weights are handed to the criterion as a third
    argument. ``lam`` must then be a tensor, and it is reshaped to a column
    (``view(-1, 1)``).
    """
    if not cond:
        weight_a = lam.view(-1, 1)
        weight_b = (1 - lam).view(-1, 1)
        return criterion(pred, y_a, weight_a) + criterion(pred, y_b, weight_b)
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


@torch.no_grad()
def mixup_data_batch_lam(args, x, y, index=None):
    '''Mix every sample with a partner, each with its own Beta(alpha, alpha) weight.

    Args:
        args: object providing ``num_labels`` (number of classes passed to
            ``onehot``) and ``alpha`` (Beta distribution parameter).
        x: input batch; dim 0 is the batch. Any rank is supported.
        y: integer class labels (converted with ``onehot``) or float targets.
            For float targets, presumably shape (batch, num_labels); TODO confirm.
        index: optional batch permutation choosing each sample's partner. When
            omitted, a random permutation is drawn on ``x``'s device.

    Returns:
        ``(mixed_x, mixed_y, y1, y2, lam, index)``. ``lam`` is a float32 tensor
        with one weight per sample, and ``mixed = lam * first + (1 - lam) * partner``.
        ``squeeze(0)`` is applied to the partner and mixed tensors, so a batch
        of one loses its batch dimension.
    '''
    n_cls, alpha = args.num_labels, args.alpha
    bsz = x.size(0)
    # TODO: in the original (reference) implementation lam is shared across
    # the batch; here each sample draws its own lam.
    lam = torch.as_tensor(np.random.beta(alpha, alpha, bsz),
                          dtype=torch.float32, device=x.device)
    # Broadcast lam over every non-batch dimension of x (not only 4-D images).
    lam_x = lam.view(-1, *([1] * (x.dim() - 1)))
    lam_y = lam.view(-1, 1)
    index = torch.randperm(bsz, device=x.device) if index is None else index
    x1, x2 = x, x[index].squeeze(0)
    if y.dtype == torch.long:
        # y[index] (not y[index, :]) so 1-D label vectors are accepted too.
        y1, y2 = onehot(y, n_cls), onehot(y[index], n_cls).squeeze(0)
    else:
        y1, y2 = y, y[index].squeeze(0)
    mixed_x = (lam_x * x1 + (1 - lam_x) * x2).squeeze(0)
    mixed_y = (lam_y * y1 + (1 - lam_y) * y2).squeeze(0)
    return mixed_x, mixed_y, y1, y2, lam, index


@torch.no_grad()
def mixup_data_single_lam(args, x, y, index=None):
    '''Mix every sample with a partner using one Beta(alpha, alpha) weight for the whole batch.

    Args:
        args: object providing ``num_labels`` (number of classes passed to
            ``onehot``) and ``alpha`` (Beta distribution parameter).
        x: input batch; dim 0 is the batch.
        y: integer class labels (converted with ``onehot``) or float targets.
        index: optional batch permutation choosing each sample's partner. When
            omitted, a random permutation is drawn on ``x``'s device.

    Returns:
        ``(mixed_x, mixed_y, y1, y2, lam, index)``, where ``lam`` is the scalar
        returned by ``np.random.beta``. ``squeeze(0)`` is applied to the partner
        and mixed tensors, so a batch of one loses its batch dimension.
    '''
    n_cls, alpha = args.num_labels, args.alpha
    bsz = x.size(0)
    # TODO: in the original (reference) implementation lam is shared across
    # the batch, which is what this variant does.
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(bsz, device=x.device) if index is None else index
    x1, x2 = x, x[index].squeeze(0)
    if y.dtype == torch.long:
        # y[index] (not y[index, :]) so 1-D label vectors are accepted too.
        y1, y2 = onehot(y, n_cls), onehot(y[index], n_cls).squeeze(0)
    else:
        y1, y2 = y, y[index].squeeze(0)
    mixed_x = (lam * x1 + (1 - lam) * x2).squeeze(0)
    mixed_y = (lam * y1 + (1 - lam) * y2).squeeze(0)
    return mixed_x, mixed_y, y1, y2, lam, index


def get_adv_x(model, x_natural, step_size=0.003, epsilon=0.031, perturb_steps=10,
              adversarial=True, distance='l_inf',):
    """Return a perturbed, detached copy of ``x_natural`` clamped to [0, 1].

    Args:
        model: network mapping inputs to logits; softmax is taken over dim 1.
        x_natural: clean input batch (values expected in [0, 1]).
        step_size: per-step size of the sign-gradient update (``l_inf`` mode).
        epsilon: ``l_inf`` radius of the perturbation ball, or the std of the
            Gaussian noise in ``l_2`` stability mode.
        perturb_steps: number of gradient steps (``l_inf`` mode).
        adversarial: ``True`` selects the adversarial ``l_inf`` attack;
            ``False`` selects Gaussian noise (``distance='l_2'`` only).
        distance: ``'l_inf'`` (adversarial) or ``'l_2'`` (stability).

    In adversarial mode the input starts at ``x_natural`` plus
    ``0.001 * N(0, 1)`` noise. It then takes ``perturb_steps`` steps of
    projected sign-gradient ascent on
    ``KL(softmax(model(x_natural)) || softmax(model(x_adv)))``
    (``nn.KLDivLoss(reduction='sum')``). After every step it is projected onto
    the ``epsilon`` ``l_inf`` ball around ``x_natural`` and clamped to [0, 1].

    The model runs in eval mode during the attack. Its previous train/eval mode
    is restored afterwards, also when an error is raised.

    Raises:
        ValueError: unsupported ``distance`` for the chosen mode.
    """
    criterion_kl = nn.KLDivLoss(reduction='sum')
    was_training = model.training
    model.eval()  # freeze batchnorm stats while crafting the perturbation
    try:
        x_adv = x_natural.detach() + 0.  # the + 0. is for copying the tensor
        if adversarial:
            if distance == 'l_inf':
                x_adv = x_adv + 0.001 * torch.randn_like(x_adv)
                # The clean prediction is the same at every step (eval mode),
                # so compute it once and without building an autograd graph.
                with torch.no_grad():
                    p_natural = F.softmax(model(x_natural), dim=1)
                for _ in range(perturb_steps):
                    x_adv.requires_grad_()
                    with torch.enable_grad():
                        loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                               p_natural)
                    grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                    x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                    # Project back onto the epsilon l_inf ball, then the valid range.
                    x_adv = torch.min(torch.max(x_adv, x_natural - epsilon),
                                      x_natural + epsilon)
                    x_adv = torch.clamp(x_adv, 0.0, 1.0)
            else:
                raise ValueError('No support for distance %s in adversarial '
                                 'training' % distance)
        else:
            if distance == 'l_2':
                x_adv = x_adv + epsilon * torch.randn_like(x_adv)
            else:
                raise ValueError('No support for distance %s in stability '
                                 'training' % distance)
    finally:
        model.train(was_training)  # restore the caller's mode
    return torch.clamp(x_adv, 0.0, 1.0).detach()


def get_calibrated_target_adv_x(ec, tc_net, model, x_natural, step_size=0.003,
                                epsilon=0.031, perturb_steps=10, adversarial=True,
                                distance='l_inf',):
    """Perturb ``x_natural`` towards ``model`` outputs that disagree with ``ec`` targets.

    Mirrors ``get_adv_x``, except that the KL target is recomputed at every step
    as ``ec.target_cond_p_st(teacher_logits, x_adv - x_natural)``. Here
    ``teacher_logits = tc_net.module.forward_cls(x_natural)`` is computed once
    without gradients. Each step is a sign-gradient ascent on
    ``KL(target_p || softmax(model.module.forward_cls(x_adv)))``, followed by
    projection onto the ``epsilon`` ``l_inf`` ball and clamping to [0, 1].

    NOTE(review): ``tc_net``/``model`` are used through ``.module.forward_cls``,
    presumably DataParallel/DDP wrappers; ``ec.target_cond_p_st`` is assumed to
    return a probability distribution over dim 1 -- confirm.

    With ``adversarial=False`` only ``distance='l_2'`` is supported. It adds
    ``epsilon``-scaled Gaussian noise, as in ``get_adv_x``.

    The model's previous train/eval mode is restored on exit, also on error.
    Returns a detached tensor clamped to [0, 1].

    Raises:
        ValueError: unsupported ``distance`` for the chosen mode.
    """
    criterion_kl = nn.KLDivLoss(reduction='sum')
    was_training = model.training
    model.eval()  # moving to eval mode to freeze batchnorm stats
    try:
        # Teacher logits only feed the (detached) targets; no graph is needed.
        with torch.no_grad():
            logits_nat_tc = tc_net.module.forward_cls(x_natural)
        x_adv = x_natural.detach() + 0.  # the + 0. is for copying the tensor
        if adversarial:
            if distance == 'l_inf':
                x_adv = x_adv + 0.001 * torch.randn_like(x_adv)
                for _ in range(perturb_steps):
                    x_adv.requires_grad_()
                    with torch.no_grad():
                        ptb = x_adv - x_natural
                        target_p = ec.target_cond_p_st(logits_nat_tc, ptb).detach()
                    with torch.enable_grad():
                        adv_log_p = F.log_softmax(model.module.forward_cls(x_adv), dim=1)
                        loss_kl = criterion_kl(adv_log_p, target_p)
                    grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                    x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                    # Project back onto the epsilon l_inf ball, then the valid range.
                    x_adv = torch.min(torch.max(x_adv, x_natural - epsilon),
                                      x_natural + epsilon)
                    x_adv = torch.clamp(x_adv, 0.0, 1.0)
            else:
                raise ValueError('No support for distance %s in adversarial '
                                 'training' % distance)
        else:
            if distance == 'l_2':
                x_adv = x_adv + epsilon * torch.randn_like(x_adv)
            else:
                raise ValueError('No support for distance %s in stability '
                                 'training' % distance)
    finally:
        model.train(was_training)  # restore the caller's mode
    return torch.clamp(x_adv, 0.0, 1.0).detach()


def xent_float_target(logit, target, n_cls, coef=1.0):
    """Cross-entropy between ``logit`` and a (possibly soft) ``target``.

    Integer ``target`` is one-hot encoded to ``n_cls`` classes first. The
    result is ``-(target * log_softmax(logit, dim=1)).mean()``, averaged over
    ALL entries. For one-hot targets this equals the usual mean cross-entropy
    divided by the number of classes.

    NOTE(review): ``coef`` is accepted but ignored. A commented-out variant
    normalized by the batch size only; confirm which scaling is intended.

    Raises:
        AssertionError: if ``logit`` and the (encoded) target differ in shape.
    """
    is_index_target = target.dtype == torch.long
    if is_index_target:
        target = F.one_hot(target, num_classes=n_cls).float()
    assert logit.size() == target.size(), (logit.shape, target.shape)
    log_probs = F.log_softmax(logit, dim=1)
    return -(target * log_probs).mean()


def xmixup_data(args, x, y, index=None):
    """Sample a point near one of two noise-perturbed inputs of a mixup pair.

    Steps:
      1. Pair each sample with ``x[index]`` (a random permutation on ``x``'s
         device when ``index`` is None).
      2. Perturb the batch with ``noisy_x(x, args.epsilon)`` and form
         ``ptb = x2 - x1`` between the perturbed pair.
      3. Draw ``p ~ Beta(alpha, alpha)``, then one Bernoulli(p) trial choosing
         the perturbed partner (x2) on success, else x1.
      4. Move the chosen point by ``args.epsilon * i * ptb / 2``, where ``i``
         is a per-sample draw from ``sample_intensity``.

    ``lam = ||x_org2 - x_sample|| / (||x_org1 - x_sample|| + ||x_org2 - x_sample||)``
    per sample, measured against the CLEAN pair. It is larger when the sample
    lies closer to ``x_org1`` and is presumably used as the ``y1`` weight (as in
    ``mixup_criterion``).

    NOTE(review): ``i.view(-1, 1, 1, 1)`` assumes ``x`` is 4-D. A clipping step
    was planned but is not implemented, so ``x_sample`` is unclipped.

    Returns:
        ``(x_sample, None, y1, y2, lam, index)``; no mixed target is built here.
    """
    n_cls, alpha = args.num_labels, args.alpha
    bsz = x.size(0)
    index = torch.randperm(bsz, device=x.device) if index is None else index
    x_org1, x_org2 = x, x[index].squeeze(0)
    x = noisy_x(x, args.epsilon)
    x1, x2 = x, x[index].squeeze(0)
    ptb = x2 - x1

    p = np.random.beta(alpha, alpha)
    # numpy's signature is binomial(n, p): one trial with success probability p.
    b = np.random.binomial(1, p)
    x_selected = x2 if b else x1
    i = sample_intensity(x)
    x_sample = x_selected + args.epsilon * i.view(-1, 1, 1, 1) * (ptb / 2)

    if y.dtype == torch.long:
        # y[index] (not y[index, :]) so 1-D label vectors are accepted too.
        y1, y2 = onehot(y, n_cls), onehot(y[index], n_cls).squeeze(0)
    else:
        y1, y2 = y, y[index].squeeze(0)

    norm1 = torch.norm((x_org1 - x_sample).reshape(bsz, -1), dim=-1)
    norm2 = torch.norm((x_org2 - x_sample).reshape(bsz, -1), dim=-1)
    lam = norm2 / (norm1 + norm2)
    return x_sample, None, y1, y2, lam, index


def noisy_x(x, coef=1.):
    """Return ``x`` displaced along a random per-sample direction.

    Each sample moves along the direction from ``sample_direction(x)`` by the
    per-sample amount from ``sample_intensity(x)`` scaled by ``coef``. The
    amount can be negative.

    NOTE(review): the ``view(-1, 1, 1, 1)`` below only broadcasts correctly
    when ``x`` is 4-D (e.g. N, C, H, W).
    """
    direction = sample_direction(x)
    magnitude = coef * sample_intensity(x).view(-1, 1, 1, 1)
    return x + magnitude * direction


def noisy_x_ptb(args, x, ptb, coef=1.):
    """Return ``x + s * ptb`` with one random scale ``s`` per sample.

    ``s`` is ``coef`` times a per-sample draw from ``sample_intensity(x)``. It
    is reshaped so that it broadcasts over every non-batch dimension of ``ptb``
    rather than along the last dimension. ``args`` is unused and kept for
    call-site compatibility.
    """
    scale = coef * sample_intensity(x)
    scale = scale.view(-1, *([1] * (ptb.dim() - 1)))
    return x + scale * ptb


def sample_direction(x):
    """Draw one random unit-L2-norm direction per sample, shaped like ``x``.

    Each sample's Gaussian draw is normalized over all of its non-batch
    elements, giving uniformly distributed directions. The result lives on
    ``x``'s device and dtype, and ``x`` may be non-contiguous.
    """
    xsh = x.shape
    d = torch.randn_like(x).reshape(xsh[0], -1)
    d_norm = torch.norm(d, dim=-1, keepdim=True)
    return (d / d_norm).reshape(xsh)


def sample_intensity(x):
    """Draw one standard-normal scalar per sample (dim 0 of ``x``) on ``x``'s device."""
    return torch.randn(x.shape[0], device=x.device)


@torch.no_grad()
def onehot(targets, num_classes):
    """Origin: https://github.com/moskomule/mixup.pytorch
    convert index tensor into onehot tensor
    :param targets: index tensor, shape (N,) or (N, 1)
    :param num_classes: number of classes
    :return: float tensor of shape (N, num_classes) on ``targets``' device
    """
    idx = targets.reshape(-1, 1)
    oh = torch.zeros(targets.size(0), num_classes, device=targets.device)
    return oh.scatter_(1, idx, 1)


def entropy_loss(unlabeled_logits):
    """Batch mean (dim 0) of the Shannon entropy, in nats, of softmax over dim 1."""
    probs = F.softmax(unlabeled_logits, dim=1)
    log_probs = F.log_softmax(unlabeled_logits, dim=1)
    per_sample = (probs * log_probs).sum(dim=1)
    return -per_sample.mean(dim=0)


def noise_loss(model,
               x_natural,
               y,
               epsilon=0.25,
               clamp_x=True):
    """Augmenting the input with random noise as in Cohen et al.

    Adds ``epsilon``-scaled standard Gaussian noise to ``x_natural``. It then
    optionally clamps the result to [0, 1] and returns the cross-entropy of
    ``model`` on it. Labels equal to -1 are ignored.
    """
    gaussian = torch.randn_like(x_natural)
    x_noise = x_natural + epsilon * gaussian
    x_noise = x_noise.clamp(0.0, 1.0) if clamp_x else x_noise
    return F.cross_entropy(model(x_noise), y, ignore_index=-1)


class MovingAvg:
    """Exponential moving average of per-batch sums, initialized on the GPU.

    ``update(e)`` folds ``e.sum(dim=0, keepdim=True)`` into the running value
    with momentum ``args.ma_momentum``. Calling the instance returns the
    current value. The average starts at zero and no bias correction is
    applied, so early values are pulled towards zero.
    """

    def __init__(self, args):
        self.ma_momentum = args.ma_momentum
        # Shape (1,) until the first update broadcasts it against the batch sum.
        self.ma = torch.zeros(1).cuda()
        # Step counter; not read anywhere in this class.
        self.t = 1

    def __call__(self):
        """Return the current moving-average tensor."""
        return self.ma

    @torch.no_grad()
    def update(self, e):
        """Update ma used for teacher output.

        The batch contribution is the SUM of ``e`` over dim 0, not the mean.
        """
        momentum = self.ma_momentum
        batch_sum = e.sum(dim=0, keepdim=True)
        self.ma = self.ma * momentum + batch_sum * (1 - momentum)


